package LiuYun

import spinal.core._
import spinal.lib._

import scala.collection.mutable.ArrayBuffer
/** CSR-derived values consumed by the execute stage: a 32-bit `tid`,
  * a 64-bit `tv` and a 19-bit `vppn`.
  */
class ExCsr extends Bundle{
    val tid :UInt = UInt(32 bits)
    val tv  :UInt = UInt(64 bits)
    val vppn:UInt = UInt(19 bits)

    /** Drives every field of this bundle from the given CSR block.
      *
      * @param csr source of `tid.tid`, `stable_timer` and `tlbehi.vppn`
      */
    def collect(csr:Csr):Unit ={
        // The three connections are independent of each other.
        this.vppn := csr.tlbehi.vppn
        this.tv   := csr.stable_timer
        this.tid  := csr.tid.tid
    }
}

/** Four one-bit TLB operation strobes: `srch`, `rd`, `wr` and `fill`. */
class TlbControl extends Bundle{
    val srch:Bool = Bool()
    val rd  :Bool = Bool()
    val wr  :Bool = Bool()
    val fill:Bool = Bool()

    /** Drives all four strobes with `False`.
      *
      * @return this bundle, so the call can be chained onto a constructor
      */
    def setEmpty:this.type = {
      // Same assignment order as the field declarations.
      Seq(srch, rd, wr, fill).foreach(_ := False)
      this
    }
}

/**
  * First execute stage of the pipeline.
  *
  * Everything below is combinational except the `StageBuffer` (`pipe`), a few
  * small state registers (`need_wait_tlb`, `tlb_roll`, `bru.repeat`) and the two
  * performance counters. From the decoded instruction on `io.idat` the stage:
  *   - computes the ALU / timer / CSR-read result written to `data.value`,
  *   - resolves branches and raises redirects on `io.brbus`, updating `io.btb`,
  *   - issues data-memory and CACOP requests on `io.dreq`,
  *   - starts multiply (Booth + Wallace) and divide operations,
  *   - drives CSR writes and TLB maintenance operations.
  *
  * NOTE: several signals (`pipe.ready`, `data.value`, `io.fw.block`,
  * `io.brbus.cancel`/`target`) are assigned in more than one place. SpinalHDL
  * resolves this as "last assignment wins", so the ORDER of the statements in
  * this class is part of the behaviour. Do not reorder them casually.
  */
class StageEx1 extends Component{
    val io = new Bundle{
        // Pipeline handshake and payload from the previous stage / to the next stage.
        val ic    :PipeCtrl = slave(new PipeCtrl)
        val idat  :SDatEx1  = in  (new SDatEx1)
        val oc    :PipeCtrl = master(new PipeCtrl)
        val odat  :SDatEx2  = out (new SDatEx2)
        // Data-memory / CACOP request channel.
        val dreq  :DataReq   = master(new DataReq)
        // Result forwarding information published by this stage.
        val fw    :Forwarding = out(new Forwarding)
        // TLB search/read/write/fill and INVTLB channels.
        val tlbrw :TlbRW      = master(new TlbRW )
        val invtlb:InvTlb     = out(new InvTlb)
        // Divider start interface.
        val div   :DivI       = master(new DivI)
        // CSR values read here (tid / stable timer) and the CSR read-write port.
        val csr   :ExCsr      = in(new ExCsr)
        val csrrw :CsrRW      = master(new CsrRW)
        // External cancel request (used as the flush input of the stage buffer).
        val exbus :CoreCancel = in  (new CoreCancel)
        // Branch redirect bus and BTB update port.
        val brbus :BrBus      = out(new BrBus)
        val btb   :BTBUpdate  = master(new BTBUpdate)
        val tlbfill_index     = in(UInt(LISA.w_tlbidx.bits)) //difftest
        val csr_estat :UInt    = in(LISA.GPR)
        val addrtrans_ok :Bool = out(Bool()) //check it in ex2, block if not ok
    }
    // Pass the translation status reported on the data request channel straight through.
    io.addrtrans_ok := io.dreq.translate_ok
    val cancel:Bool = io.exbus.cancel
    // `data` is the payload for the next stage. `assignedFrom` is a project helper;
    // presumably it copies the same-named fields of io.idat as defaults - TODO confirm.
    val data:SDatEx2 = new SDatEx2 assignedFrom io.idat
    val pipe = new StageBuffer[SDatEx2](data)
    pipe.update(io.ic,io.oc,cancel)
    io.odat := pipe.data
    val I = Instructions

    // Conditions collected here are OR-ed into `need_empty` near the end of the class.
    val wait_empty:ArrayBuffer[Bool] = new ArrayBuffer[Bool]
    val need_empty:Bool = Bool()
    // Set while a TLB operation (invtlb / srch / rd) has been issued and has not finished.
    val need_wait_tlb:Bool = RegInit(False)
    // Set by an INVTLB so that a redirect is raised once the operation finishes.
    val tlb_roll:Bool = RegInit(False)
    val tlb_working = Bool()
    // Default readiness. The tlb, div and mem areas below override it inside `when` blocks.
    pipe.ready := (io.ic.empty || !need_empty) && !tlb_working
    // Incoming instruction is valid and io.ic.ex is not set
    // (presumably "already carries an exception" - TODO confirm in PipeCtrl).
    val valid_ok:Bool = io.ic.valid && !io.ic.ex

    // Forwarding: the result is flagged as blocked for div / mul / mem instructions
    // (the mul and mem areas below also force `block` when they are active).
    io.fw.valid := io.ic.valid
    io.fw.grwr  := data.grwr
    io.fw.block := io.idat.itype.is_div || io.idat.itype.is_mul || io.idat.itype.is_mem
    io.fw.dest  := data.dest
    io.fw.value := data.value

    // Sub-opcode and immediate fields sliced out of the instruction word:
    // ui12 = inst[21:10], ui14 = inst[23:10], ui5 = inst[14:10], si20 = inst[24:5].
    val op:UInt = io.idat.subcode
    val ui12:UInt = io.idat.inst(10, 12 bits)
    val ui14:UInt = io.idat.inst(10, 14 bits)
    val ui5 :UInt = io.idat.inst(10,  5 bits)
    val si12:SInt = ui12.asSInt
    val si14:SInt = ui14.asSInt
    val si20:SInt = io.idat.inst(5, 20 bits).asSInt
    val imm :Bool = io.idat.itype.imm
    val link:Bool = io.idat.itype.link
    val ll:Bool = io.idat.itype.is_ll
    val sc:Bool = io.idat.itype.is_sc
    val sign:Bool = io.idat.itype.sign
    // True when the top six instruction bits are 010101.
    val link_to_ra:Bool = io.idat.inst.asBits.takeHigh(6) === B"01_0101" //|| io.idat.inst.asBits.takeHigh(6) === B"01_0011" && io.idat.inst.asBits.takeLow(5) === B"00001" //BL || JIRL && rd == ra
    // Exact match on one 32-bit encoding: opcode 010011, 16 zero bits, then fields 00001 and 00000
    // (presumably `jirl r0, ra, 0`, i.e. a function return - TODO confirm against the ISA manual).
    val is_return:Bool = io.idat.inst.asBits === B"01_0011_000000_00000_00000_00001_00000"
    // Timer reads: counter id, or the high / low 32 bits of the 64-bit timer value.
    val timer = new Area{
        val res_id:UInt = io.csr.tid
        val res_vh:UInt = io.csr.tv(63 downto 32)
        val res_vl:UInt = io.csr.tv(31 downto  0)
        val res:UInt = I.select(op,"alu.t",Map(
            "id"->res_id,
            "vh"->res_vh,
            "vl"->res_vl)
        )
    }
    val seqpc:UInt = io.idat.pc + U(4)
    // ALU operands are forced to zero for non-ALU instructions.
    val is_alu:Bool = io.idat.itype.is_alu
    val alu_a :UInt = Mux(is_alu,io.idat.a,U(0))
    val alu_b :UInt = Mux(is_alu,io.idat.b,U(0))
    // Add / subtract / set-less-than / pc-relative add.
    val arith = new Area{
        val pcadd:Bool = I.isSubOpCode(op,"alu.a","pcadd")
        // Operand A: register value by default, the PC for pcadd.
        val a:SInt = alu_a.asSInt
        when(pcadd){
            a := io.idat.pc.asSInt
        }
        // Operand B: register value by default, sign-extended si12 for immediate forms,
        // si20 << 12 for pcadd (the later `when` wins if both apply).
        val b:SInt = alu_b.asSInt
        when(imm){
            b := si12.resize(LISA.w_gpr)
        }
        when(pcadd){
            b := si20 << 12
        }
        // Extend both operands by one bit: a copy of the MSB when `sign` is set, otherwise 0,
        // so a single adder yields a correct signed or unsigned comparison.
        val ea:SInt = (a.msb && sign).asSInt @@ a
        val eb:SInt = (b.msb && sign).asSInt @@ b
        val neg:Bool = I.isSubOpCode(op,"alu.a","sub") || I.isSubOpCode(op,"alu.a","slt")
        val slt:Bool = I.isSubOpCode(op,"alu.a","slt")
        // a + b, or a + ~b + 1 (two's-complement subtraction) when `neg` is set.
        val ec:SInt = ea + Mux(neg,~eb,eb) + (S(0, 32 bits) @@ neg.asSInt)
        // For slt the result is the sign bit of the extended difference, zero-extended.
        val res_slt:UInt = U(0, 31 bits) @@ ec.msb.asUInt
        val res_a_s:UInt = ec(a.bitsRange).asUInt
        val res:UInt = Mux(slt,res_slt,res_a_s)
    }
    // Shifter built from one rotate plus masks instead of separate left/right shifters.
    val shift = new Area{
        val a:UInt = alu_a
        // Shift amount: low five bits of operand B, or ui5 for immediate forms.
        val b:UInt = alu_b(4 downto 0)
        when(imm){
            b := ui5
        }
        val l:Bool = I.isSubOpCode(op,"alu.s","l")
        // Rotate amount: b for left shifts, ~b for right shifts
        // (a right shift additionally pre-rotates by one, see a_r).
        val sa:UInt = Mux(l,b,~b)
        // ge: all-ones constant shifted left by sa; gt: the same shifted by one more bit.
        val ge:UInt = U("1"*a.getWidth) |<< sa
        val gt :UInt = ge |<< 1
        val msb:Bool = sign && a.msb
        // mask_l / mask_r select which rotated bits are kept for left / right shifts.
        val mask_l:UInt = ge
        val mask_r:UInt = ~gt
        // mask_s is OR-ed into the result when `sign` is set and the operand's MSB is 1.
        val mask_s:UInt = Mux(msb,gt,U(0))
        val mask:UInt = Mux(l,mask_l,mask_r)
        val a_l:UInt = a
        val a_r:UInt = a.rotateLeft(1)
        val raw:UInt = Mux(l,a_l,a_r).rotateLeft(sa)
        // `&` binds tighter than `|`: (raw & mask) | mask_s.
        val res:UInt = raw & mask | mask_s
    }
    // Bitwise operations and lui.
    val logic = new Area{
        val a:UInt = alu_a
        // Operand B: register value, or zero-extended ui12 for immediate forms.
        val b:UInt = alu_b
        when(imm){
            b := ui12.resize(LISA.w_gpr)
        }
        val res_or :UInt = a | b
        val res_xor:UInt = a ^ b
        val res_and:UInt = a & b
        val res_nor:UInt = ~(a|b)
        val res_lui:UInt = si20.asUInt << 12
        val res:UInt = I.select(op,"alu.l",Map(
            "nor"->res_nor,
            "and"->res_and,
            "or" ->res_or ,
            "xor"->res_xor,
            "lui"->res_lui)
        )
    }
    // Final ALU result, chosen among the four areas above.
    val value:UInt = I.select(op,"alu",Map(
        "t" -> timer.res,
        "a" -> arith.res,
        "s" -> shift.res,
        "l" -> logic.res)
    )
    // Result written to the next stage. Later assignments override earlier ones:
    // ALU value < link address (pc + 4) < sc result (llbit) < CSR read data (in the csr area).
    data.value := value
    when(link){
        data.value := seqpc
    }
    when(sc){
        data.value := U(0,31.bits)@@io.csrrw.llbit
    }
    val idle:Bool = valid_ok && I.isInst(io.idat.inst, "IDLE")
    wait_empty.append(idle)
    // CSR read / write / exchange.
    val csr = new Area{
        val valid:Bool = valid_ok && io.idat.itype.is_csr
        val xchg:Bool = valid && I.isSubOpCode(op,"csr","x")
        val wr  :Bool = valid && I.isSubOpCode(op,"csr","w")
        val rd  :Bool = valid && I.isSubOpCode(op,"csr","r")
        // Instructions that may modify CSR state; ll / sc are included here
        // (presumably because they update the LLBit - TODO confirm).
        val mayw:Bool = xchg || wr || valid_ok && (ll || sc)
        // The write enable is only raised once io.ic.empty is set.
        val we  :Bool = mayw && io.ic.empty
        // set to reserved id when no R/W
        io.csrrw.a  := Mux(valid,io.idat.inst(10, 14 bits),U(0x1000))
        io.csrrw.wd := Mux(xchg || wr, io.idat.b, U(0))
        // Write mask: operand A for an exchange, all ones for a plain write, zero otherwise.
        io.csrrw.wm := ( Mux(xchg, io.idat.a, U(0))
                       | Mux(wr  , U("1"*32), U(0)))
        io.csrrw.en := we
        // Any CSR instruction returns the value read from the CSR port.
        when(valid){
            data.value := io.csrrw.rd
        }
        wait_empty.append(mayw)
    }
    // TLB maintenance: TLBSRCH / TLBRD / TLBWR / TLBFILL / INVTLB.
    val tlb = new Area{
        // Requested operation decoded from the current instruction.
        val may :TlbControl = new TlbControl
        // NOTE(review): may_valid is declared but never driven or read in this class.
        val may_valid = Bool()
        // All-False control bundle, used when the request must not be issued yet.
        val nop :TlbControl = new TlbControl setEmpty
        val may_inv:Bool = valid_ok && io.idat.itype.is_invtlb
        // INVTLB fires only when io.ic.empty is set and no TLB operation is pending.
        val inv  :Bool = io.ic.empty && may_inv && !need_wait_tlb
        // need_wait_tlb: set when an invtlb / srch / rd is issued, cleared on finish.
        need_wait_tlb.riseWhen(io.invtlb.en || io.tlbrw.srch || io.tlbrw.rd)
        need_wait_tlb.fallWhen(io.tlbrw.finish)
        // tlb_roll: set by an INVTLB, cleared on finish; feeds bru.roll below.
        tlb_roll.riseWhen(io.invtlb.en)
        tlb_roll.fallWhen(io.tlbrw.finish)
        // Busy while an operation is being issued or is pending, until finish is seen.
        tlb_working := ((inv || io.tlbrw.srch || io.tlbrw.rd) || need_wait_tlb) && !io.tlbrw.finish

        may.rd   := valid_ok && Instructions.isInst(io.idat.inst, "TLBRD")
        may.fill := valid_ok && Instructions.isInst(io.idat.inst, "TLBFILL")
        may.wr   := valid_ok && Instructions.isInst(io.idat.inst, "TLBWR")
        //io.tlbrw.inv  := inv
        // for timing
        may.srch := valid_ok && io.idat.itype.is_tlbsrch
        io.tlbrw.valid := (may.rd || may.fill || may.wr || may.srch || inv) && !need_wait_tlb // once the request has fired and we are waiting, valid must be deasserted
        // TLBSRCH also flush pipeline to prevent CSR hazard
        val mayw   :Bool = may.fill || may.wr || may_inv || may.srch || may.rd
        //inv should roll after finish
        // While a TLB request is being presented, override the default readiness.
        when(io.tlbrw.valid){
            pipe.ready := io.ic.empty && io.tlbrw.ready && !tlb_working
        }
        //io.tlbrw := Mux(io.ic.empty && io.tlbrw.ready,may,nop)
        // Drive the same-named control fields of io.tlbrw from `may` only when
        // io.ic.empty and io.tlbrw.ready are both set; otherwise drive them all False.
        when(io.ic.empty && io.tlbrw.ready){
            io.tlbrw.assignSomeByName(may)
            //io.tlbrw.valid := may_valid
        }.otherwise{
            io.tlbrw.assignSomeByName(nop)
            //io.tlbrw.valid := False
        }
        // INVTLB operands: op = inst[4:0], asid from operand A, va from operand B.
        io.invtlb.en   := inv
        io.invtlb.op   := io.idat.inst(0, 5 bits)
        io.invtlb.asid := io.idat.a(0, LISA.w_asid bits)
        wait_empty.append(mayw)
        io.invtlb.may := may_inv
        io.invtlb.va  := io.idat.b
    }

    // Driven from mem.is_op further down; declared here because bru.roll uses it.
    val is_cacop = Bool()
    // Branch resolution and redirect generation.
    val bru = new Area{
        val valid:Bool = valid_ok && io.idat.itype.is_bru
        val j:Bool = I.isSubOpCode(op,"bru","j")
        val b:Bool = I.isSubOpCode(op,"bru","b")
        val is_b26:Bool = b && I.isSubOpCode(op,"bru.b","26")
        val is_b16:Bool = b && I.isSubOpCode(op,"bru.b","16")
        // 16-bit offset = inst[25:10]; 26-bit offset = inst[9:0] ## inst[25:10].
        val imm16:UInt = io.idat.inst(25 downto 10)
        val imm26:UInt = io.idat.inst( 9 downto  0) @@ imm16
        val offs16:SInt = imm16.asSInt.resize(26)
        val offs26:SInt = imm26.asSInt
        // Base and offset default to zero and are only driven for a valid branch.
        val base:UInt = U(0)
        val offs:UInt = U(0)
        when(valid){
            // Base is operand A when `j` is set, otherwise the PC.
            base := Mux(j,io.idat.a,io.idat.pc)
            // Only the upper ten offset bits differ between the two formats; the low
            // sixteen bits are shared. The offset is then shifted left by two.
            offs := (Mux(is_b26, offs26(25 downto 16), offs16(25 downto 16)).resize(14)
                    @@ offs16(15 downto 0)
                    @@ S(0, 2 bits)).asUInt
        }
        // One-bit extended operands so the same comparator serves signed and unsigned forms.
        val ea:SInt = ((io.idat.a.msb && sign).asUInt @@ io.idat.a).asSInt
        val eb:SInt = ((io.idat.b.msb && sign).asUInt @@ io.idat.b).asSInt
        val is_eq :Bool = ea === eb
        val is_ne :Bool = !is_eq
        val is_lt :Bool = ea  <  eb
        val is_ge :Bool = !is_lt
        val is_ok :Bool = I.select(op,"bru.b.16",Map(
            "eq"->is_eq.asUInt,
            "ne"->is_ne.asUInt,
            "lt"->is_lt.asUInt,
            "ge"->is_ge.asUInt)
        ).asBool
        val target:UInt = base + offs
        // `&&` binds tighter than `||`: j || is_b26 || (is_b16 && is_ok).
        val taken :Bool = j || is_b26 || is_b16 && is_ok
        // Cleared whenever io.ic.allow is set, otherwise set while io.ic.valid is set.
        // It masks the mispredict cancel below (presumably so a stalled branch only
        // redirects once - TODO confirm).
        // NOTE(review): this register has no reset value, unlike the RegInit registers
        // above - confirm that is intentional.
        val repeat:Bool = Reg(Bool())
        when(io.ic.allow){
            repeat := False
        }.elsewhen(io.ic.valid){
            repeat := True
        }

        val validtaken : Bool = valid && taken
        // shouldTaken: the branch is taken but was predicted not-taken or to a different target.
        // shouldNotTaken: the instruction is not a taken branch but was predicted taken.
        val shouldTaken :Bool = Bool()
        val shouldNotTaken :Bool = Bool()
        // Driven from io.idat.btbsave.taken near the end of the class.
        val prediction :Bool = Bool()

        when(validtaken){
            shouldTaken := !prediction || target =/= io.idat.predtarget
        }.otherwise{
            shouldTaken := False
        }

        when(io.ic.valid && !validtaken){
            shouldNotTaken := prediction
        }.otherwise{
            shouldNotTaken := False
        }

        // Default redirect: a mispredicted taken branch jumps to the computed target.
        io.brbus.cancel := shouldTaken && !repeat
        io.brbus.target := target
        io.brbus.idle   := idle
        io.brbus.ghr := Mux(shouldTaken, THR.update(io.idat.ghr, data.pc, target), io.idat.ghr) //restore ghr
        // Conditions that restart fetch at the next sequential PC: a CSR/TLB-modifying or
        // IDLE instruction once io.ic.empty is set, a CACOP, a wrongly predicted-taken
        // instruction, or the completion of an INVTLB.
        val roll:Bool = io.ic.empty && (csr.mayw || tlb.mayw || idle) || 
                        is_cacop ||
                        shouldNotTaken || // io.ic.allow)   //need receive this inst then cancel 
                        tlb_roll && io.tlbrw.finish //need roll after tlbrw done
        
        // Overrides the default redirect assigned above.
        when(roll){
            io.brbus.cancel := True
            io.brbus.target := seqpc//Mux(need_wait_tlb && io.tlbrw.finish, io.idat.pc, seqpc)
        }
    }
    // Multiplier front end: Booth encoding followed by a Wallace tree.
    val mul = new Area{
        val valid:Bool = io.idat.itype.is_mul
        val a:UInt = io.idat.a
        val b:UInt = io.idat.b
        // Operands are extended by one bit (the MSB when `sign` is set) and zeroed when
        // the instruction is not a multiply.
        val ea:SInt = Mux(valid,(sign && a.msb) ## a,B(0)).asSInt
        val eb:SInt = Mux(valid,(sign && b.msb) ## b,B(0)).asSInt
        val bt = new BoothEncode(ea,eb)
        val wt = new WallaceTree(bt)
        // Two outputs of the tree are passed on (presumably summed in a later stage - TODO confirm).
        data.mul_x := wt.res(0)
        data.mul_y := wt.res(1)
        when(valid){
            io.fw.block := True
        }
    }
    // Divider start: operands are zeroed unless a valid divide is present.
    val div = new Area{
        val valid:Bool = io.idat.itype.is_div && valid_ok
        io.div.valid := valid && pipe.go
        io.div.a := Mux(valid,io.idat.a,U(0))
        io.div.b := Mux(valid,io.idat.b,U(0))
        io.div.sign := valid && io.idat.itype.sign
        // A divide can only leave this stage when the divider accepts it.
        when(valid){
            pipe.ready := io.div.allow
        }
    }
    // Load / store / preload / CACOP request generation.
    val mem = new Area{
        val is_op:Bool = io.idat.itype.is_cacop && valid_ok
        val is_me:Bool = io.idat.itype.is_mem && valid_ok
        val valid:Bool = is_op || is_me
        val ld:Bool = is_me && I.isSubOpCode(op,"mem","ld")
        val st:Bool = is_me && I.isSubOpCode(op,"mem","st")
        val pr:Bool = is_me && I.isSubOpCode(op,"mem","pr")
        // Address offset: si14 << 2 for ll / sc, si12 otherwise, sign-extended to 32 bits.
        val sint:SInt = Mux(ll || sc,(si14@@S(0,2.bits)).resize(32),si12.resize(32))
        val base:UInt = Mux(valid, io.idat.a, U(0))
        val offs:UInt = Mux(valid, sint, S(0)).asUInt
        val va:UInt = (base + offs)
        val ld_is_w:Bool = I.isSubOpCode(op,"mem.ld","w")
        val ld_is_h:Bool = I.isSubOpCode(op,"mem.ld","h")
        val ld_is_b:Bool = I.isSubOpCode(op,"mem.ld","b")
        // A word store is suppressed for an sc whose llbit is clear.
        val st_is_w:Bool = I.isSubOpCode(op,"mem.st","w") && !(sc && !io.csrrw.llbit)
        val st_is_h:Bool = I.isSubOpCode(op,"mem.st","h")
        val st_is_b:Bool = I.isSubOpCode(op,"mem.st","b")
        val mem_is_w:Bool = ld && ld_is_w || st && st_is_w
        val mem_is_h:Bool = ld && ld_is_h || st && st_is_h
        val mem_is_b:Bool = ld && ld_is_b || st && st_is_b
        // Misaligned access: word with va[1:0] != 0, or half-word with va[0] != 0.
        val ale:Bool = ( mem_is_w && va(1 downto 0).orR
                      || mem_is_h && va(0 downto 0).orR)

        is_cacop := is_op
        // Requests are only issued when the stage advances (pipe.go) and pipe.hasex is clear.
        io.dreq.valid := pipe.go && valid && !pipe.hasex
        io.dreq.va  := va
        io.dreq.mem.ld := ld
        io.dreq.mem.st := st
        io.dreq.mem.pr := pr
        io.dreq.mem.w  := mem_is_w
        io.dreq.mem.h  := mem_is_h
        io.dreq.mem.b  := mem_is_b
        io.dreq.cacop.valid := pipe.go && is_op && !pipe.hasex
        io.dreq.cacop.code  := io.idat.inst(4 downto 0)

        // Store data is replicated across the 32-bit word (1 x word, 2 x half, 4 x byte);
        // the byte strobes select the addressed lanes.
        val wd:UInt = Mux(st,io.idat.b,U(0))
        val wd_w:UInt = wd
        val wd_h:UInt = wd(15 downto 0)
        val wd_b:UInt = wd( 7 downto 0)
        val wd_1w:UInt = wd_w
        val wd_2h:UInt = wd_h @@ wd_h
        val wd_4b:UInt = wd_b @@ wd_b @@ wd_b @@ wd_b
        val wdata:UInt = ( Mux(st_is_w,wd_1w,U(0))
                         | Mux(st_is_h,wd_2h,U(0))
                         | Mux(st_is_b,wd_4b,U(0)))
        // Base strobe pattern for the access size, shifted left by va[1:0].
        val wstrb:Bits = ( Mux(st_is_w,B("1111"),B(0))                         
                         | Mux(st_is_h,B("0011"),B(0))
                         | Mux(st_is_b,B("0001"),B(0))) |<< va(1 downto 0)
        io.dreq.wdata := wdata
        io.dreq.wstrb := wstrb
        // A memory instruction waits for the request channel; ll / sc additionally
        // wait for io.ic.empty. This is the last override of pipe.ready in the class.
        when(valid){
            pipe.ready := Mux((ll || sc), io.dreq.allow && io.ic.empty, io.dreq.allow)
        }
        when(valid){
            io.fw.block := True
        }
    }
    // OR of every condition registered in wait_empty (idle, csr.mayw, tlb.mayw).
    need_empty := wait_empty.reduce(_||_)
//    when(need_empty && !mem.valid){
//        pipe.ready := io.ic.empty
//    }
    // A misaligned access raises a new exception in this stage.
    pipe.newex := mem.ale
    // va carries the PC when io.ic.ex is already set, otherwise the memory address.
    data.va := Mux(io.ic.ex,io.idat.pc,mem.va)
    // Only record ALE when io.ic.ex is not already set.
    when(!io.ic.ex){
        when(mem.ale){
            data.ecode := LISA.Ex.ALE.code
        }
    }

    // BTB update: echo the saved lookup info and report the resolved outcome.
    io.btb.data.hit := io.idat.btbsave.hit
    bru.prediction := io.idat.btbsave.taken
    io.btb.data.index := io.idat.btbsave.index

    io.btb.pc := io.idat.pc
    io.btb.taken := bru.validtaken
    io.btb.target := bru.target
    io.btb.valid := io.ic.valid && io.ic.allow && (bru.valid || bru.shouldNotTaken) // also update the BTB for a non-branch that was predicted taken
    io.btb.ghr := io.idat.ghr


    // Branch type reported to the BTB, in priority order:
    // 0 = link_to_ra, 3 = is_return, 1 = j or 26-bit branch, 2 = everything else.
    when(link_to_ra){
        io.btb.brutype := B(0)
    }.elsewhen(is_return){
        io.btb.brutype := B(3)
    }.elsewhen(bru.j || bru.is_b26){
        io.btb.brutype := B(1)
    }.otherwise{
        io.btb.brutype := B(2)
    }

    // Extra payload fields named for difftest (co-simulation checking).
    val difftest = new Area{
        val ld_w :Bool = mem.ld && mem.ld_is_w
        val ld_h :Bool = mem.ld && io.idat.itype.sign && mem.ld_is_h
        val ld_hu :Bool = mem.ld && !io.idat.itype.sign && mem.ld_is_h
        val ld_b :Bool = mem.ld && io.idat.itype.sign && mem.ld_is_b
        val ld_bu :Bool = mem.ld && !io.idat.itype.sign && mem.ld_is_b

        // Stores are only reported when the memory request is actually issued.
        val st_w :Bool = mem.st && mem.st_is_w && io.dreq.valid
        val st_h :Bool = mem.st && mem.st_is_h && io.dreq.valid
        val st_b :Bool = mem.st && mem.st_is_b && io.dreq.valid


        // Expand each byte-strobe bit into eight bits to form a 32-bit data mask.
        val wstrb32 :Bits = Cat(mem.wstrb.asBools.map(_ #* 8))


        data.tv := io.csr.tv
        // Store data with the lanes that are not written masked to zero.
        data.wdata := mem.wdata & wstrb32.asUInt
        // `##` places the left operand in the higher bits, so ld_w / st_w are the MSBs.
        data.difftest_load := ld_w##ld_hu##ld_h##ld_bu##ld_b //5 bits
        data.difftest_store := st_w##st_h##st_b //3 bits

        data.tlbfill_index := io.tlbfill_index
        data.csr_estat := io.csr_estat
    }

/*==========================Performance Counters==============================*/
    // Both counters increment on every clock cycle in which their condition is high.
    // NOTE(review): if an instruction can stay in this stage for several cycles these
    // count cycles rather than instructions - confirm that is what is wanted.
    val pm_mispred = UInt(32 bits) setAsReg() init(0)
    when(bru.shouldNotTaken || bru.shouldTaken){
        pm_mispred := pm_mispred + U(1)
    }
    val pm_br_count = UInt(32 bits) setAsReg() init(0)
    when(bru.valid){
        pm_br_count := pm_br_count + U(1)
    }

/*==========================Performance Counters==============================*/
}
